#!/usr/bin/env python
# encoding: utf-8
 
import socket
from binascii import hexlify
from collections import deque
from hashlib import sha1
from random import randint
from socket import inet_ntoa
from struct import unpack
from threading import Timer, Thread
from time import sleep
 
from bencode import bencode, bdecode
 
# (host, port) pairs that DHTClient.join_DHT() sends its initial find_node
# queries to.
BOOTSTRAP_NODES = (
  ("router.bittorrent.com", 6881),
  ("dht.transmissionbt.com", 6881),
  ("router.utorrent.com", 6881)
)
# Length of the random transaction id ("t") put in every outgoing query.
TID_LENGTH = 2
# Delay handed to threading.Timer (i.e. seconds) between two re-join checks.
RE_JOIN_DHT_INTERVAL = 3
# Number of leading info_hash bytes used as the get_peers token.
TOKEN_LENGTH = 2
 
 
def entropy(length):
  """Return `length` random bytes as a byte string.

  The result is a `str` on Python 2 and `bytes` on Python 3. Bytes are drawn
  one at a time from random.randint(0, 255), so the output is reproducible
  under random.seed(). A `length` of zero or less yields an empty string.
  """
  # bytes(bytearray(...)) is spelled the same on Python 2 and 3, unlike
  # "".join(chr(...)) over xrange, which only works on Python 2.
  return bytes(bytearray(randint(0, 255) for _ in range(length)))
 
 
def random_id():
  """Return a fresh 20-byte id: the SHA-1 digest of 20 bytes from entropy()."""
  return sha1(entropy(20)).digest()
 
 
def decode_nodes(nodes):
  """Split a packed node list into (nid, ip, port) tuples.

  Each entry is 26 bytes: a 20-byte node id, a 4-byte IPv4 address and a
  2-byte big-endian port. If the input length is not a multiple of 26 the
  whole blob is rejected and an empty list is returned.
  """
  size = len(nodes)
  if size % 26:
    return []

  return [
    (
      nodes[pos:pos + 20],
      inet_ntoa(nodes[pos + 20:pos + 24]),
      unpack("!H", nodes[pos + 24:pos + 26])[0],
    )
    for pos in range(0, size, 26)
  ]
 
 
def timer(t, f):
  """Call `f()` once, `t` seconds from now, on a background thread.

  The thread is a daemon, so a pending (or self-re-arming) timer never keeps
  the interpreter alive after the main thread has finished.

  Returns the started threading.Timer so the caller may cancel() it.
  """
  pending = Timer(t, f)
  pending.daemon = True
  pending.start()
  return pending
 
 
def get_neighbor(target, nid, end=10):
  """Return an id made of the first `end` bytes of `target` followed by
  `nid` from offset `end` onwards.
  """
  shared_prefix = target[:end]
  own_suffix = nid[end:]
  return shared_prefix + own_suffix
 
 
class KNode(object):
  """Plain record for one remote node: its id and the ip/port it was seen at."""

  def __init__(self, nid, ip, port):
    self.nid, self.ip, self.port = nid, ip, port
 
 
class DHTClient(Thread):
  """Outgoing half of the crawler.

  Keeps a bounded queue of discovered nodes and sends find_node queries to
  them. The methods read `self.ufd` (the UDP socket) and `self.bind_ip`,
  neither of which is assigned in this class; a subclass has to provide them.
  """

  def __init__(self, max_node_qsize):
    """Create a daemon thread with a random node id and an empty node queue
    that holds at most `max_node_qsize` entries.
    """
    Thread.__init__(self)
    self.setDaemon(True)
    self.max_node_qsize = max_node_qsize
    self.nid = random_id()
    # With maxlen set, appending to a full deque drops the oldest entry.
    self.nodes = deque(maxlen=max_node_qsize)

  def send_krpc(self, msg, address):
    """Bencode `msg` and send it to `address`; every failure is ignored."""
    try:
      payload = bencode(msg)
      self.ufd.sendto(payload, address)
    except Exception:
      pass

  def send_find_node(self, address, nid=None):
    """Send a find_node query for a random target to `address`.

    When `nid` (the recipient's id) is given, the sender id is built with
    get_neighbor(nid, self.nid); otherwise our own id is used as is.
    """
    if nid:
      sender_id = get_neighbor(nid, self.nid)
    else:
      sender_id = self.nid
    tid = entropy(TID_LENGTH)
    query = {
      "t": tid,
      "y": "q",
      "q": "find_node",
      "a": {"id": sender_id, "target": random_id()},
    }
    self.send_krpc(query, address)

  def join_DHT(self):
    """Send a find_node query to every entry of BOOTSTRAP_NODES."""
    for bootstrap in BOOTSTRAP_NODES:
      self.send_find_node(bootstrap)

  def re_join_DHT(self):
    """Bootstrap again if the node queue is empty, then re-arm this check
    to run after RE_JOIN_DHT_INTERVAL.
    """
    if not self.nodes:
      self.join_DHT()
    timer(RE_JOIN_DHT_INTERVAL, self.re_join_DHT)

  def auto_send_find_node(self):
    """Loop forever: pop the oldest queued node, query it, then pause.

    The pause is 1.0 / max_node_qsize seconds. An empty queue is not an
    error; the loop just sleeps and tries again.
    """
    pause = 1.0 / self.max_node_qsize
    while True:
      try:
        candidate = self.nodes.popleft()
        self.send_find_node((candidate.ip, candidate.port), candidate.nid)
      except IndexError:
        pass
      sleep(pause)

  def process_find_node_response(self, msg, address):
    """Queue every usable node listed in a find_node response.

    Entries are skipped when the id is not 20 bytes long, the ip equals
    self.bind_ip, or the port is outside 1..65535.
    """
    for nid, ip, port in decode_nodes(msg["r"]["nodes"]):
      if len(nid) != 20 or ip == self.bind_ip or not (1 <= port <= 65535):
        continue
      self.nodes.append(KNode(nid, ip, port))
 
 
class DHTServer(DHTClient):
  """Receiving half of the crawler.

  Binds the UDP socket, answers incoming KRPC queries and hands every
  accepted announce_peer to `master.log()`.
  """

  def __init__(self, master, bind_ip, bind_port, max_node_qsize):
    """Bind the UDP socket and arm the periodic re-join timer.

    master         -- object whose log(infohash, address, tname) is called
                      for every accepted announce_peer.
    bind_ip        -- local address the UDP socket is bound to.
    bind_port      -- local UDP port.
    max_node_qsize -- forwarded to DHTClient.__init__.
    """
    DHTClient.__init__(self, max_node_qsize)

    self.master = master
    self.bind_ip = bind_ip
    self.bind_port = bind_port

    # Queries that get a real answer; any other query type ends up in
    # play_dead() (see on_message).
    self.process_request_actions = {
      "get_peers": self.on_get_peers_request,
      "announce_peer": self.on_announce_peer_request,
    }

    self.ufd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    self.ufd.bind((self.bind_ip, self.bind_port))

    # Single starting point of the re_join_DHT chain; it re-arms itself.
    timer(RE_JOIN_DHT_INTERVAL, self.re_join_DHT)

  def run(self):
    """Thread body: bootstrap once, then receive and dispatch datagrams
    forever. Errors while receiving, decoding or handling are ignored.
    """
    # Only the initial join happens here. Calling re_join_DHT() would arm a
    # second timer next to the one started in __init__, leaving two
    # permanent re-join chains running side by side.
    if not self.nodes:
      self.join_DHT()
    while True:
      try:
        (data, address) = self.ufd.recvfrom(65536)
        msg = bdecode(data)
        self.on_message(msg, address)
      except Exception:
        # Best effort: one bad packet must not stop the receive loop.
        pass

  def on_message(self, msg, address):
    """Dispatch one decoded KRPC message.

    Responses ("y" == "r") that carry "nodes" feed the node queue. Queries
    ("y" == "q") go to the handler registered in process_request_actions;
    an unknown or missing query name is answered by play_dead().
    """
    try:
      if msg["y"] == "r":
        if "nodes" in msg["r"]:
          self.process_find_node_response(msg, address)
      elif msg["y"] == "q":
        try:
          self.process_request_actions[msg["q"]](msg, address)
        except KeyError:
          self.play_dead(msg, address)
    except KeyError:
      pass

  def on_get_peers_request(self, msg, address):
    """Answer get_peers with an empty node list and a token.

    The token is the first TOKEN_LENGTH bytes of the requested info_hash,
    which is what on_announce_peer_request later checks against. Requests
    missing "info_hash", "t" or "id" get no answer.
    """
    try:
      infohash = msg["a"]["info_hash"]
      tid = msg["t"]
      # Not used below, but a request without an id must stay unanswered.
      msg["a"]["id"]
      token = infohash[:TOKEN_LENGTH]
      msg = {
        "t": tid,
        "y": "r",
        "r": {
          "id": get_neighbor(infohash, self.nid),
          "nodes": "",
          "token": token
        }
      }
      self.send_krpc(msg, address)
    except KeyError:
      pass

  def on_announce_peer_request(self, msg, address):
    """Report a valid announce_peer to the master, then always reply via ok().

    The announce is accepted when its token equals the first TOKEN_LENGTH
    bytes of its info_hash. The reported port is the UDP source port when
    "implied_port" is present and non-zero, otherwise the "port" argument
    (rejected when outside 1..65535). The optional "name" argument is passed
    on as `tname`, or None when the peer did not send one.
    """
    try:
      args = msg["a"]
      infohash = args["info_hash"]
      token = args["token"]
      # "name" is not among the announce_peer arguments defined by BEP 5;
      # only some clients add it, so its absence must not drop the announce.
      tname = args.get("name")
      # "id" and "t" stay mandatory, exactly as before.
      if "id" not in args or "t" not in msg:
        return

      if infohash[:TOKEN_LENGTH] == token:
        if args.get("implied_port", 0) != 0:
          port = address[1]
        else:
          port = args["port"]
          if port < 1 or port > 65535:
            return
        self.master.log(infohash, (address[0], port), tname)
    except Exception:
      # Best effort: a malformed announce or a failing master.log() is
      # dropped; the reply in `finally` is still attempted.
      pass
    finally:
      self.ok(msg, address)

  def play_dead(self, msg, address):
    """Reply to a query with error 202 ("Server Error"); needs msg["t"]."""
    try:
      tid = msg["t"]
      msg = {
        "t": tid,
        "y": "e",
        "e": [202, "Server Error"]
      }
      self.send_krpc(msg, address)
    except KeyError:
      pass

  def ok(self, msg, address):
    """Send a bare success response whose id is derived from the sender's
    id via get_neighbor(); nothing is sent when "t" or "id" is missing.
    """
    try:
      tid = msg["t"]
      nid = msg["a"]["id"]
      msg = {
        "t": tid,
        "y": "r",
        "r": {
          "id": get_neighbor(nid, self.nid)
        }
      }
      self.send_krpc(msg, address)
    except KeyError:
      pass
 
 
class Master(object):
  """Default consumer of crawler results: prints each announce to stdout."""

  def log(self, infohash, address=None, tname=None):
    """Print one announced torrent followed by its magnet link.

    infohash -- raw info_hash bytes; printed hex-encoded.
    address  -- (ip, port) of the announcing peer, or None when unknown.
    tname    -- torrent name sent by the peer, or None; the magnet link
                carries a `dn` parameter only when a name is present.
    """
    hexinfohash = hexlify(infohash)
    if not isinstance(hexinfohash, str):
      # Python 3: hexlify() returns bytes, which would print as b'...'.
      hexinfohash = hexinfohash.decode("ascii")

    # The declared default address=None used to crash on address[0].
    if address is None:
      origin = "unknown"
    else:
      origin = "%s:%s" % (address[0], address[1])
    # print(<single expression>) behaves the same on Python 2 and 3.
    print("info_hash is: %s,name is: %s from %s" % (hexinfohash, tname, origin))

    magnet = "magnet:?xt=urn:btih:%s" % (hexinfohash,)
    if tname is not None:
      magnet += "&dn=%s" % (tname,)
    print(magnet)
 
 
# Usage example
if __name__ == "__main__":
  # A larger max_node_qsize means more bandwidth used and faster crawling.
  crawler = DHTServer(
    master=Master(),
    bind_ip="0.0.0.0",
    bind_port=6882,
    max_node_qsize=200,
  )
  crawler.start()
  # Runs in the main thread and never returns.
  crawler.auto_send_find_node()